#! /bin/env python 
'''
Created on Sep 30, 2012

@author: Shunping Huang
'''

import os
import sys
import gc
import pysam
import argparse as ap
import logging
import multiprocessing as mp
import time

from modtools.mod import Mod
from modtools import tmpmod

from lapels.matefixer import *
from lapels.utils import readableFile, writableFile, validChromList
from lapels import annotator as annotator
import lapels.version


# One-line program description passed to the argparse parser.
DESC = "A remapper and annotator of in silico (pseudo) genome alignments."
# Default verbosity; overwritten from the '-v' option when run as a script.
VERBOSITY = 1
# Version of this script, reported by '-V' next to the package version.
VERSION = '0.0.5'
PKG_VERSION = lapels.version.__version__

# Shared state: assigned by the script entry point (or initLogger) and
# read by annotate().
# Maps a BAM reference name to its read count (from 'idxstats' output).
nReadsInChroms = dict()
# Prefix used to build the names of temporary and output BAM files.
outPrefix = None
# Header dict used when writing BAM files.
outHeader = None
# Root logger; assigned by initLogger().
logger = None


def validTagPrefix(s):
    '''
    argparse ``type=`` checker for a tag prefix.

    Returns `s` unchanged when it is a single character that is either
    'X', 'Y', 'Z' or lower-case; raises argparse.ArgumentTypeError for
    anything else (wrong length, or a character outside that set).
    '''
    # Guard clause: the prefix must be exactly one character long.
    if len(s) != 1:
        raise ap.ArgumentTypeError(
            "Tag Prefix '%s' should have exactly one character." % s)

    # Only 'X', 'Y', 'Z' and lower-case letters are accepted.
    accepted = s in ('X', 'Y', 'Z') or s[0].islower()
    if not accepted:
        raise ap.ArgumentTypeError(
            "Tag Prefix '%s' is not valid/reserved for local use." % s)

    return s


def initLogger():
    '''
    Set up console logging and publish the root logger as the
    module-level `logger`.

    The root logger itself is set to DEBUG; the stream handler attached
    to it only lets INFO and above through, and stamps each record with
    a "[date time] name: LEVEL: message" layout.
    '''
    global logger

    # Build the console handler first, fully configured.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(
        "[%(asctime)s] %(name)-10s: %(levelname)s: %(message)s",
        "%Y-%m-%d %H:%M:%S"))

    # Then attach it to the root logger.
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    logger.addHandler(console)


def _logInfoLocked(lock, msg, *args):
    '''Log an INFO record, holding `lock` (when one is given) while the
    record is emitted.'''
    if lock:
        with lock:
            logger.info(msg, *args)
    else:
        logger.info(msg, *args)


def annotate(bamfile, tmpmod, outChrom, mergePool, lock=None):
    '''
    Annotate the reads of one chromosome and queue the result for merging.

    The reads of `outChrom` are fetched from `bamfile`, run through
    annotator.Annotator together with the MOD data loaded for that
    chromosome, written to "<outPrefix>.<outChrom>.unsorted.bam" and then
    sorted by read name into "<outPrefix>.<outChrom>.sorted.bam". The
    unsorted file is removed and the sorted file name is appended to
    `mergePool`.

    Besides its arguments, this function reads the module-level names
    nReadsInChroms, outPrefix, outHeader, chromAliases, tagPrefixes and
    logger, which the script entry point must have set beforehand.

    Arguments:
        bamfile   -- path of the (indexed) input BAM file
        tmpmod    -- path handed to Mod() to open the MOD data
        outChrom  -- chromosome name to process
        mergePool -- list (or manager list proxy) collecting the names of
                     the sorted per-chromosome BAM files
        lock      -- optional lock held around log calls and passed on to
                     the Annotator; None in single-process mode

    Raises:
        ValueError -- if no name matching `outChrom` is found among the
                      BAM reference names in nReadsInChroms
    '''
    global nReadsInChroms
    global outPrefix
    global outHeader

    # The garbage collector is switched off for the duration of the call;
    # try/finally guarantees it is switched back on even on failure.
    gc.disable()
    try:
        _logInfoLocked(lock, "processing chromosome '%s'", outChrom)

        inFile = pysam.Samfile(bamfile, 'rb')
        try:
            chrom = chromAliases.getBasicName(outChrom)

            mod = Mod(tmpmod)
            modChrom = chromAliases.getMatchedAlias(chrom, mod.chroms)
            if modChrom is None:
                _logInfoLocked(lock, "alias not found for '%s' in MOD",
                               outChrom)
                # Fall back to the basic name when MOD has no alias for it.
                modChrom = str(chrom)
            else:
                _logInfoLocked(lock, "alias '%s' used for '%s' in MOD",
                               modChrom, outChrom)

            mod.load(modChrom)

            bamChrom = chromAliases.getMatchedAlias(chrom,
                                                    nReadsInChroms.keys())
            if bamChrom is None:
                raise ValueError("Unable to determine the name of '%s' in BAM."
                                 % outChrom)
            _logInfoLocked(lock, "alias '%s' used for '%s' in BAM",
                           bamChrom, outChrom)
            bamIter = inFile.fetch(bamChrom)
            nReads = nReadsInChroms[bamChrom]

            unsortedFileName = "%s.%s.unsorted.bam" % (outPrefix, outChrom)
            tmpFile = pysam.Samfile(unsortedFileName, 'wb', header=outHeader,
                                    referencenames=inFile.references)
            try:
                a = annotator.Annotator(modChrom,
                                        mod.meta.getChromLength(chrom), mod,
                                        bamIter, nReads, tagPrefixes, tmpFile,
                                        lock)
                a.execute()
            finally:
                # Close the output even when annotation fails.
                tmpFile.close()
        finally:
            inFile.close()

        # Sort by read name, required by fixmate
        _logInfoLocked(lock, "sorting reads in '%s' by names ...", outChrom)
        # Build the sorted name from its parts instead of substituting
        # 'unsorted' -> 'sorted' over the whole path, which would also
        # rewrite a directory or prefix that happens to contain 'unsorted'.
        sortedPrefix = "%s.%s.sorted" % (outPrefix, outChrom)
        # The sort call takes an output prefix; the result is expected at
        # sortedPrefix + '.bam'.
        pysam.sort('-n', unsortedFileName, sortedPrefix)
        os.remove(unsortedFileName)
        mergePool.append(sortedPrefix + '.bam')
    finally:
        gc.enable()
         
         
def worker(workerId, bamfile, tmpmod, queue, mergePool, lock):
    '''
    Body of one worker process: repeatedly claim the next unprocessed
    entry of `queue` and annotate it, until every entry has been claimed.

    The position of the next unclaimed entry is kept in the module-level
    shared counter `qidx` (a multiprocessing.Value created by the script
    entry point before the workers are started).

    Arguments:
        workerId  -- index of this worker (not used by the body)
        bamfile   -- path of the input BAM file, passed to annotate()
        tmpmod    -- MOD path, passed to annotate()
        queue     -- indexable sequence of chromosome names to process
        mergePool -- shared list collecting the sorted output file names
        lock      -- lock passed to annotate()
    '''
    global qidx
    qlen = len(queue)
    while True:
        # Read-and-increment on a shared Value is not atomic, so the claim
        # is done under the counter's own lock; otherwise two workers could
        # take the same index (and write the same temporary file) or skip
        # one.
        with qidx.get_lock():
            idx = qidx.value
            if idx >= qlen:
                return
            qidx.value = idx + 1
        annotate(bamfile, tmpmod, queue[idx], mergePool, lock)

    

# Script entry point: parse the command line, annotate every requested
# chromosome (in one process or several), merge the per-chromosome results,
# fix mate information, and write the final BAM file (plus its index when
# the output is sorted by position).
if __name__ == '__main__':
    initLogger()

    # Parse arguments
    p = ap.ArgumentParser(description=DESC,
                          formatter_class=ap.RawTextHelpFormatter)
    p.add_argument('-V', '--version', action='version', version='%(prog)s' +
                   ' %s in Lapels %s' % (VERSION, PKG_VERSION))
    group = p.add_mutually_exclusive_group()
    group.add_argument("-q", dest='quiet', action='store_true',
                       help='quiet mode')
    group.add_argument('-v', dest='verbosity', action="store_const", const=2,
                       default=1, help="verbose mode")
    p.add_argument('-t', dest='keepTemp', action='store_true',
                   help="keep temporary files (default: no)")
    p.add_argument('-n', dest='sortByName', action='store_true',
                   help='output bam file sorted by read names (default: no)')
    p.add_argument('-p', metavar='nProcesses', dest='nProcesses',
                   type=int, default=1,
                   help='number of processes to run (default: 1)')
    p.add_argument('-c', metavar='chromList', dest='chroms',
                   type=validChromList, default=set(),
                   help='a comma-separated list of chromosomes (default: all)')
    p.add_argument('--ts', metavar='prefix', type=validTagPrefix, default='s',
                   help='tag prefix for numbers of observed SNPs (default: s)')
    p.add_argument('--ti', metavar='prefix', type=validTagPrefix, default='i',
                   help='tag prefix for numbers of bases in observed insertions'
                        + ' (default: i)')
    p.add_argument('--td', metavar='prefix', type=validTagPrefix, default='d',
                   help='tag prefix for numbers of bases in observed deletions'
                        + ' (default: d)')
    p.add_argument('inMod', metavar='in.mod', type=readableFile,
                   help='the mod file of the in silico genome')
    p.add_argument('inBam', metavar='in.bam', type=readableFile,
                   help='the input bam file')
    p.add_argument('outBam', metavar='out.bam', nargs='?', type=writableFile,
                   default=None, help='the output bam file'
                        + ' (default: input.annotated.bam)')
    args = p.parse_args()

    # NOTE(review): the console handler installed by initLogger() is fixed
    # at INFO, so raising the logger to DEBUG here does not by itself make
    # DEBUG records appear on the console.
    if args.quiet:
        logger.setLevel(logging.CRITICAL)
    elif args.verbosity == 2:
        logger.setLevel(logging.DEBUG)

    VERBOSITY = args.verbosity
    annotator.VERBOSITY = VERBOSITY

    # Check if input bam index exists.
    if not os.path.isfile(args.inBam + '.bai'):
        logger.info("creating index for input BAM file")
        pysam.index(args.inBam)
        if os.path.isfile(args.inBam + '.bai'):
            logger.info("index created")
        else:
            # No exception is active here, so log a plain error rather than
            # logger.exception (which would append an empty traceback).
            logger.error("index failed")
            raise IOError("Failed to create index. " +
                          "Please make sure the bam file is sorted by position.")

    if args.outBam is None:
        if args.inBam.endswith('.bam'):
            # Swap only the trailing extension; a '.bam' elsewhere in the
            # path (e.g. in a directory name) must be left alone.
            outFileName = args.inBam[:-len('.bam')] + '.annotated.bam'
        else:
            outFileName = args.inBam + '.annotated.bam'
    else:
        outFileName = args.outBam

    # Prefix for temporary files and for the sort output.  An output name
    # without a '.bam' suffix is used as the prefix unchanged.
    if outFileName.endswith('.bam'):
        outPrefix = outFileName[:-len('.bam')]
    else:
        outPrefix = outFileName
    tagPrefixes = [args.ts, args.ti, args.td]

    logger.info("input MOD file: %s", args.inMod)
    logger.info("input BAM file: %s", args.inBam)
    logger.info("output BAM file: %s", outFileName)

    # A compromise: adding complexity but reducing unnecessary argument.
    # From here on `tmpmod` names the path returned by getTabixMod(), no
    # longer the imported module.
    tmpmod = tmpmod.getTabixMod(args.inMod)

    mod = Mod(tmpmod)
    chromAliases = mod.meta.chromAliases

    chroms = args.chroms
    if len(args.chroms) == 0:
        chroms = mod.chroms
    # worker() indexes this collection, so make sure it is a sequence (the
    # argparse default, for one, is a set).
    chroms = list(chroms)

    # Get the number of reads in each chromosome
    nReadsInChroms = dict()
    for idxstat in pysam.idxstats(args.inBam):
        tup = idxstat.rstrip('\n').split('\t')
        nReadsInChroms[tup[0]] = int(tup[2])

    inFile = pysam.Samfile(args.inBam, 'rb')
    try:
        outHeader = dict(inFile.header.items())
    finally:
        inFile.close()

    # Append a PG tag in the header of output bam
    try:
        outHeader['PG'] = [{'ID': 'Lapels', 'VN': PKG_VERSION,
                            'PP': outHeader['PG'][0]['ID'],
                            'CL': ' '.join(sys.argv)}] + outHeader['PG']
    except (KeyError, IndexError):
        # If there is no usable 'PG' entry to chain to, add a new one
        outHeader['PG'] = [{'ID': 'Lapels', 'VN': PKG_VERSION,
                            'CL': ' '.join(sys.argv)}]

    # Correct reference lengths in the header.
    for chrDict in outHeader['SQ']:
        sn = chrDict['SN']
        chrDict['LN'] = mod.meta.getChromLength(sn)
        if chrDict['LN'] is None:
            raise ValueError("Unable to find the length of %s in bam." %
                             chrDict['SN'])

    # Record the sort order of the output.  The input may lack an @HD line,
    # in which case one is created.  The SAM specification spells the
    # read-name order 'queryname'.
    hdLine = outHeader.setdefault('HD', {'VN': '1.0'})
    if args.sortByName:
        hdLine['SO'] = 'queryname'
    else:
        # The output bam file will be sorted by position.
        hdLine['SO'] = 'coordinate'

    # Annotate using multiple processes or a single process
    # The output is sorted by names
    nProcesses = args.nProcesses
    if VERBOSITY > 1:
        logger.warning("force to use a single process in verbose mode")
        nProcesses = 1

    if nProcesses > 1:
        logger.info("use multiple processes: %d", nProcesses)
        lock = mp.Lock()
        manager = mp.Manager()
        mergePool = manager.list()
        # Shared index of the next chromosome to hand out; read by worker().
        qidx = mp.Value('i', 0)
        procs = []
        try:
            for i in range(nProcesses):
                proc = mp.Process(target=worker,
                                  args=(i, args.inBam, tmpmod, chroms,
                                        mergePool, lock))
                proc.start()
                procs.append(proc)
        except Exception:
            # Do not leave already-started workers running behind a failure.
            for proc in procs:
                proc.terminate()
            logger.exception("failed to start worker processes")
            raise RuntimeError("Cannot use multiple processes.")

        # Wait for every worker explicitly instead of polling the number of
        # active children (which also counts the manager process).
        for proc in procs:
            proc.join()
        failed = [proc.name for proc in procs if proc.exitcode != 0]

        # Take a plain copy so the rest of the script does not depend on
        # the manager process.
        mergePool = list(mergePool)
        manager.shutdown()

        # A worker that died leaves its chromosome(s) out of mergePool;
        # stop here rather than produce a silently incomplete output.
        if failed:
            raise RuntimeError("Worker process(es) failed: %s"
                               % ', '.join(failed))
    else:
        logger.info("use a single process")
        mergePool = []
        for outChrom in chroms:
            annotate(args.inBam, tmpmod, outChrom, mergePool)

    nMerges = len(mergePool)
    if nMerges == 0:
        # Explicit check: an assert would vanish under 'python -O'.
        raise RuntimeError("No chromosome was annotated; nothing to merge.")

    mergedFileName = outPrefix + '.merged.bam'
    if nMerges > 1:
        # Merge
        logger.info("merging %d files ...", nMerges)
        pysam.merge('-f', '-n', mergedFileName, *mergePool)
        if not args.keepTemp:
            for fn in mergePool:
                os.remove(fn)
    else:
        os.rename(mergePool[0], mergedFileName)

    # Fix mates
    logger.info("fixing mate ...")
    matefixedFileName = outPrefix + '.matefixed.bam'
    fixmate(mergedFileName, matefixedFileName)
    if not args.keepTemp:
        os.remove(mergedFileName)

    if not args.sortByName:
        # Sort by position
        logger.info("sorting reads by positions ...")
        # The sort call takes an output prefix; the result is expected at
        # outPrefix + '.bam'.
        pysam.sort(matefixedFileName, outPrefix)
        if not args.keepTemp:
            os.remove(matefixedFileName)
        if outPrefix + '.bam' != outFileName:
            # Only when the requested output name has no '.bam' suffix.
            os.rename(outPrefix + '.bam', outFileName)

        # Build index for output
        logger.info("creating bam index for output")
        pysam.index(outFileName)
        if os.path.isfile(outFileName + '.bai'):
            logger.info("index created")
        else:
            logger.warning("index failed")
    else:
        os.rename(matefixedFileName, outFileName)

    # NOTE(review): assumes mod.fileName is the temporary file produced by
    # getTabixMod() above, not the user's input MOD file -- confirm.
    os.remove(mod.fileName)
    os.remove(mod.fileName + '.tbi')

    logger.info("All Done!")
    logging.shutdown()